{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3b44703",
   "metadata": {},
   "source": [
    "# Answer correctness prediction with SparkXshards on Orca"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416b0a4b",
   "metadata": {},
   "source": [
    "Copyright 2016 The BigDL Authors."
   ]
  },
  {
   "cell_type": "raw",
   "id": "a474a628",
   "metadata": {},
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf304c91",
   "metadata": {},
   "source": [
    "SparkXshards in Orca allows users to process large-scale dataset using existing Python codes in a distributed and data-parallel fashion, as shown below. This notebook is an example of feature engineering for [LightGBM](https://github.com/intel-analytics/BigDL/blob/main/python/dllib/src/bigdl/dllib/nnframes/tree_model.py) using SparkXshards on Orca. \n",
    "\n",
    "This notebook is adapted from [Riiid! Answer Correctness Prediction EDA. Modeling](https://www.kaggle.com/code/isaienkov/riiid-answer-correctness-prediction-eda-modeling).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebde4a88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import necessary libraries\n",
    "from bigdl.orca import init_orca_context, stop_orca_context\n",
    "import bigdl.orca.data.pandas\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03eff331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start an OrcaContext\n",
    "sc = init_orca_context(memory=\"8g\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46ad6a7",
   "metadata": {},
   "source": [
    "## Load data in parallel and get general information"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904060e4",
   "metadata": {},
   "source": [
    "Load data into data_shards, it is a SparkXshards that can be operated on in parallel, here each element of the data_shards is a panda dataframe read from a file on the cluster. Users can distribute local code of `pd.read_csv(dataFile)` using `bigdl.orca.data.pandas.read_csv(datapath)`. The full dataset is more than 5G, you can sample a small portion to run the flow fast."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8702a15b",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../answer_correctness/train.csv'\n",
    "used_data_types_list = [\n",
    "    'timestamp',\n",
    "    'user_id',\n",
    "    'content_id',\n",
    "    'answered_correctly',\n",
    "    'prior_question_elapsed_time',\n",
    "    'prior_question_had_explanation'\n",
    "]\n",
    "data_shards = bigdl.orca.data.pandas.read_csv(path,\n",
    "                                             usecols=used_data_types_list,\n",
    "                                             index_col=0)\n",
    "# sample data_shards if needed\n",
    "data_shards = data_shards = data_shards.sample(0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d12d912",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>user_id</th>\n",
       "      <th>content_id</th>\n",
       "      <th>answered_correctly</th>\n",
       "      <th>prior_question_elapsed_time</th>\n",
       "      <th>prior_question_had_explanation</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1693389</th>\n",
       "      <td>2877721343</td>\n",
       "      <td>35865606</td>\n",
       "      <td>1900</td>\n",
       "      <td>1</td>\n",
       "      <td>26333.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1978897</th>\n",
       "      <td>8385893042</td>\n",
       "      <td>41563837</td>\n",
       "      <td>183</td>\n",
       "      <td>1</td>\n",
       "      <td>26666.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>245353</th>\n",
       "      <td>3720820653</td>\n",
       "      <td>4663746</td>\n",
       "      <td>6152</td>\n",
       "      <td>0</td>\n",
       "      <td>13000.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2103373</th>\n",
       "      <td>22086593201</td>\n",
       "      <td>43827287</td>\n",
       "      <td>1113</td>\n",
       "      <td>1</td>\n",
       "      <td>17000.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>647272</th>\n",
       "      <td>1027906533</td>\n",
       "      <td>13149581</td>\n",
       "      <td>6799</td>\n",
       "      <td>1</td>\n",
       "      <td>57750.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           timestamp   user_id  content_id  answered_correctly  \\\n",
       "1693389   2877721343  35865606        1900                   1   \n",
       "1978897   8385893042  41563837         183                   1   \n",
       "245353    3720820653   4663746        6152                   0   \n",
       "2103373  22086593201  43827287        1113                   1   \n",
       "647272    1027906533  13149581        6799                   1   \n",
       "\n",
       "         prior_question_elapsed_time prior_question_had_explanation  \n",
       "1693389                      26333.0                           True  \n",
       "1978897                      26666.0                           True  \n",
       "245353                       13000.0                           True  \n",
       "2103373                      17000.0                           True  \n",
       "647272                       57750.0                           True  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the first couple of rows in the data_shards\n",
    "data_shards.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3b8526d9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1012301"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count total number of rows in the data_shards\n",
    "len(data_shards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "454939be",
   "metadata": {},
   "source": [
    "## Feature engineering"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2639019",
   "metadata": {},
   "source": [
    "Get 90% of the data for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f66319cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "def get_feature(df):\n",
    "    feature_df = df.iloc[:int(9/10 * len(df))]\n",
    "    return feature_df\n",
    "feature_shards = data_shards.transform_shard(get_feature)"
   ]
  },
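  {
   "cell_type": "markdown",
   "id": "sketch-shard-split",
   "metadata": {},
   "source": [
    "The split above runs once per shard, so each shard keeps its own leading 90% of rows rather than the leading 90% of the whole dataset. Below is a minimal sketch of that behaviour in plain pandas, assuming a SparkXshards can be pictured as a list of DataFrames; `toy_shards` and `first_90_percent` are hypothetical names used only for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Two toy shards standing in for the partitions of a SparkXshards (illustrative assumption)\n",
    "toy_shards = [pd.DataFrame({'x': range(10)}), pd.DataFrame({'x': range(10, 30)})]\n",
    "\n",
    "def first_90_percent(df):\n",
    "    # Keep the leading 90% of the rows of this shard only\n",
    "    return df.iloc[:int(9 / 10 * len(df))]\n",
    "\n",
    "# Plain-Python stand-in for transform_shard: apply the function to every shard\n",
    "split_sizes = [len(first_90_percent(df)) for df in toy_shards]\n",
    "print(split_sizes)\n",
    "```\n",
    "\n",
    "The cut-off is computed from each shard's own length, so shards of different sizes contribute different numbers of rows."
   ]
  },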
  {
   "cell_type": "markdown",
   "id": "c2c722df",
   "metadata": {},
   "source": [
    "'answered_correctly' is our target, filter data with the correct label. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ef7764d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "def get_train_questions_only(df):\n",
    "    train_questions_only_df = df[df['answered_correctly'] != -1]\n",
    "    return train_questions_only_df\n",
    "train_questions_only_shards = feature_shards.transform_shard(get_train_questions_only)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11caa8ff",
   "metadata": {},
   "source": [
    "Extract statistic features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3f30775c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>content_id</th>\n",
       "      <th>answered_correctly</th>\n",
       "      <th>prior_question_elapsed_time</th>\n",
       "      <th>prior_question_had_explanation</th>\n",
       "      <th>avg(answered_correctly)</th>\n",
       "      <th>count(answered_correctly)</th>\n",
       "      <th>stddev_samp(answered_correctly)</th>\n",
       "      <th>skewness(answered_correctly)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>921758</td>\n",
       "      <td>808038</td>\n",
       "      <td>7218</td>\n",
       "      <td>1</td>\n",
       "      <td>24000.0</td>\n",
       "      <td>False</td>\n",
       "      <td>0.5</td>\n",
       "      <td>2</td>\n",
       "      <td>0.707107</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>921758</td>\n",
       "      <td>39606020</td>\n",
       "      <td>6475</td>\n",
       "      <td>0</td>\n",
       "      <td>15000.0</td>\n",
       "      <td>True</td>\n",
       "      <td>0.5</td>\n",
       "      <td>2</td>\n",
       "      <td>0.707107</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1263038</td>\n",
       "      <td>55294</td>\n",
       "      <td>7963</td>\n",
       "      <td>1</td>\n",
       "      <td>16000.0</td>\n",
       "      <td>False</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1263038</td>\n",
       "      <td>438857</td>\n",
       "      <td>1249</td>\n",
       "      <td>1</td>\n",
       "      <td>14000.0</td>\n",
       "      <td>True</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2706752</td>\n",
       "      <td>685394944</td>\n",
       "      <td>5669</td>\n",
       "      <td>1</td>\n",
       "      <td>17000.0</td>\n",
       "      <td>True</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  timestamp  content_id  answered_correctly  \\\n",
       "0   921758     808038        7218                   1   \n",
       "1   921758   39606020        6475                   0   \n",
       "2  1263038      55294        7963                   1   \n",
       "3  1263038     438857        1249                   1   \n",
       "4  2706752  685394944        5669                   1   \n",
       "\n",
       "   prior_question_elapsed_time prior_question_had_explanation  \\\n",
       "0                      24000.0                          False   \n",
       "1                      15000.0                           True   \n",
       "2                      16000.0                          False   \n",
       "3                      14000.0                           True   \n",
       "4                      17000.0                           True   \n",
       "\n",
       "   avg(answered_correctly)  count(answered_correctly)  \\\n",
       "0                      0.5                          2   \n",
       "1                      0.5                          2   \n",
       "2                      1.0                          2   \n",
       "3                      1.0                          2   \n",
       "4                      1.0                          1   \n",
       "\n",
       "   stddev_samp(answered_correctly)  skewness(answered_correctly)  \n",
       "0                         0.707107                           0.0  \n",
       "1                         0.707107                           0.0  \n",
       "2                         0.000000                           NaN  \n",
       "3                         0.000000                           NaN  \n",
       "4                              NaN                           NaN  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_questions_only_shards = \\\n",
    "    train_questions_only_shards.group_by(columns='user_id', agg={\"answered_correctly\": ['mean',\n",
    "                                                                                       'count',\n",
    "                                                                                       'stddev',\n",
    "                                                                                       'skewness']\n",
    "                                                                }, join=True)\n",
    "train_questions_only_shards.head(5)"
   ]
  },
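  {
   "cell_type": "markdown",
   "id": "sketch-groupby-join",
   "metadata": {},
   "source": [
    "With `join=True`, the output above shows every original row extended with its user's aggregates. A rough pandas equivalent is sketched below on hypothetical toy data (`toy`, `stats`, and `joined` are illustrative names). Skewness is left out because pandas and Spark may not use the same estimator.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy interactions: user 1 answered twice, user 2 once (made-up data)\n",
    "toy = pd.DataFrame({'user_id': [1, 1, 2], 'answered_correctly': [1, 0, 1]})\n",
    "\n",
    "# Per-user statistics of the label; pandas 'std' is the sample standard deviation\n",
    "stats = (toy.groupby('user_id')['answered_correctly']\n",
    "            .agg(['mean', 'count', 'std'])\n",
    "            .reset_index())\n",
    "\n",
    "# Join the statistics back so every original row carries its user's aggregates\n",
    "joined = toy.merge(stats, on='user_id', how='left')\n",
    "print(joined)\n",
    "```\n",
    "\n",
    "A user with a single row has an undefined sample standard deviation, which is why NaN values appear and need to be filled in the next step."
   ]
  },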
  {
   "cell_type": "markdown",
   "id": "0e6ca8f7",
   "metadata": {},
   "source": [
    "Fill None values with value of 0.5, and transform binary feature of 'prior_question_had_explanation' as integer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "decd664b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "def transform(df, val):\n",
    "    train_df = df.fillna(value=val)\n",
    "    train_df['prior_question_had_explanation'] = train_df['prior_question_had_explanation'].astype(int)\n",
    "    return train_df\n",
    "train_shards = train_questions_only_shards.transform_shard(transform, 0.5)\n"
   ]
  },
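  {
   "cell_type": "markdown",
   "id": "sketch-fillna-cast",
   "metadata": {},
   "source": [
    "To make the effect of this step concrete, here is a small pandas sketch on hypothetical toy data (the `stddev` column name and its values are made up). It fills a missing statistic with 0.5 and casts the boolean flag to 0/1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "toy = pd.DataFrame({\n",
    "    'stddev': [0.7, np.nan],\n",
    "    'prior_question_had_explanation': [True, False],\n",
    "})\n",
    "\n",
    "# Missing statistics become 0.5; columns without NaN are left unchanged\n",
    "filled = toy.fillna(value=0.5)\n",
    "\n",
    "# True/False become 1/0 so the flag can be used as a numeric feature\n",
    "filled['prior_question_had_explanation'] = filled['prior_question_had_explanation'].astype(int)\n",
    "print(filled)\n",
    "```"
   ]
  },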
  {
   "cell_type": "markdown",
   "id": "be54ad71",
   "metadata": {},
   "source": [
    "Assembly 'features' and rename label columns to 'label'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e6d0770f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>features</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[24000.0, 0.0, 0.5, 2.0, 0.7071067811865476, 0.0]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[15000.0, 1.0, 0.5, 2.0, 0.7071067811865476, 0.0]</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[16000.0, 0.0, 1.0, 2.0, 0.0, 0.5]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[14000.0, 1.0, 1.0, 2.0, 0.0, 0.5]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[17000.0, 1.0, 1.0, 1.0, 0.5, 0.5]</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            features  label\n",
       "0  [24000.0, 0.0, 0.5, 2.0, 0.7071067811865476, 0.0]      1\n",
       "1  [15000.0, 1.0, 0.5, 2.0, 0.7071067811865476, 0.0]      0\n",
       "2                 [16000.0, 0.0, 1.0, 2.0, 0.0, 0.5]      1\n",
       "3                 [14000.0, 1.0, 1.0, 2.0, 0.0, 0.5]      1\n",
       "4                 [17000.0, 1.0, 1.0, 1.0, 0.5, 0.5]      1"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_cols = ['prior_question_elapsed_time', 'prior_question_had_explanation', 'avg(answered_correctly)', 'count(answered_correctly)', 'stddev_samp(answered_correctly)', 'skewness(answered_correctly)']\n",
    "\n",
    "def assembly(df):\n",
    "    y = df.rename({'answered_correctly': 'label'}, axis=1)['label']\n",
    "    df['features'] = df[feature_cols].values.tolist()\n",
    "    df = df['features']\n",
    "    df1 = pd.concat([df, y], axis=1)\n",
    "    return df1\n",
    "\n",
    "train_shards = train_shards.transform_shard(assembly)\n",
    "train_shards.head(5)"
   ]
  },
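  {
   "cell_type": "markdown",
   "id": "sketch-assemble-features",
   "metadata": {},
   "source": [
    "The assembly step packs several numeric columns into one list per row. The sketch below shows the same idea in isolation; the column names `f1` and `f2` and the variable `toy_feature_cols` are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "toy = pd.DataFrame({'f1': [1.0, 2.0], 'f2': [0.5, 0.25], 'answered_correctly': [1, 0]})\n",
    "toy_feature_cols = ['f1', 'f2']\n",
    "\n",
    "# One Python list per row, holding that row's feature values in column order\n",
    "assembled = pd.DataFrame({\n",
    "    'features': toy[toy_feature_cols].values.tolist(),\n",
    "    'label': toy['answered_correctly'],\n",
    "})\n",
    "print(assembled)\n",
    "```\n",
    "\n",
    "The order of `toy_feature_cols` fixes the position of each feature inside the list, so it should stay the same between training and prediction."
   ]
  },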
  {
   "cell_type": "markdown",
   "id": "eae67300",
   "metadata": {},
   "source": [
    "Current LightGBM requires training data as Spark dataframe with a feature column of pyspark `DenseVector`, transform SparkXshards to Spark dataframe required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "52a186c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.linalg import DenseVector, VectorUDT\n",
    "from pyspark.sql.functions import udf, col\n",
    "sdf = train_shards.to_spark_df()\n",
    "sdf = sdf.withColumn('features', udf(lambda x: DenseVector(x), VectorUDT())(col('features')))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea48ca52",
   "metadata": {},
   "source": [
    "## Train a LightGBM model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5f4357d",
   "metadata": {},
   "source": [
    "Build a LightGBMClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e3be1983",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.dllib.nnframes.tree_model import LightGBMClassifier, LightGBMClassifierModel\n",
    "\n",
    "params = {\"boosting_type\": \"gbdt\", \"num_leaves\": 70, \"learning_rate\": 0.3,\n",
    "          \"min_data_in_leaf\": 20, \"objective\": \"binary\",\n",
    "          'num_iterations': 1000,\n",
    "          'max_depth': 14,\n",
    "          'lambda_l1': 0.01,\n",
    "          'lambda_l2': 0.01,\n",
    "          'bagging_freq': 5,\n",
    "          'max_bin': 255,\n",
    "          'early_stopping_round': 20\n",
    "          }\n",
    "\n",
    "estimator = LightGBMClassifier(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c28ed25",
   "metadata": {},
   "source": [
    "Train a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "504f7150",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 50:>                                                         (0 + 2) / 2]\n",
      "User settings:\n",
      "\n",
      "   KMP_AFFINITY=granularity=fine,compact,1,0\n",
      "   KMP_BLOCKTIME=0\n",
      "   KMP_SETTINGS=1\n",
      "   OMP_NUM_THREADS=1\n",
      "\n",
      "Effective settings:\n",
      "\n",
      "   KMP_ABORT_DELAY=0\n",
      "   KMP_ADAPTIVE_LOCK_PROPS='1,1024'\n",
      "   KMP_ALIGN_ALLOC=64\n",
      "   KMP_ALL_THREADPRIVATE=128\n",
      "   KMP_ATOMIC_MODE=2\n",
      "   KMP_BLOCKTIME=0\n",
      "   KMP_DETERMINISTIC_REDUCTION=false\n",
      "   KMP_DEVICE_THREAD_LIMIT=2147483647\n",
      "   KMP_DISP_NUM_BUFFERS=7\n",
      "   KMP_DUPLICATE_LIB_OK=false\n",
      "   KMP_ENABLE_TASK_THROTTLING=true\n",
      "   KMP_FORCE_MONOTONIC_DYNAMIC_SCHEDULE=false\n",
      "   KMP_FORCE_REDUCTION: value is not defined\n",
      "   KMP_FOREIGN_THREADS_THREADPRIVATE=true\n",
      "   KMP_FORKJOIN_BARRIER='2,2'\n",
      "   KMP_FORKJOIN_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_FORKJOIN_FRAMES=true\n",
      "   KMP_FORKJOIN_FRAMES_MODE=3\n",
      "   KMP_GTID_MODE=0\n",
      "   KMP_HANDLE_SIGNALS=false\n",
      "   KMP_HOT_TEAMS_MAX_LEVEL=1\n",
      "   KMP_HOT_TEAMS_MODE=0\n",
      "   KMP_INIT_AT_FORK=true\n",
      "   KMP_ITT_PREPARE_DELAY=0\n",
      "   KMP_LIBRARY=throughput\n",
      "   KMP_LOCK_KIND=queuing\n",
      "   KMP_MALLOC_POOL_INCR=1M\n",
      "   KMP_NESTING_MODE=0\n",
      "   KMP_NUM_LOCKS_IN_BLOCK=1\n",
      "   KMP_PLAIN_BARRIER='2,2'\n",
      "   KMP_PLAIN_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_REDUCTION_BARRIER='1,1'\n",
      "   KMP_REDUCTION_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_SCHEDULE='static,balanced;guided,iterative'\n",
      "   KMP_SETTINGS=true\n",
      "   KMP_SPIN_BACKOFF_PARAMS='4096,100'\n",
      "   KMP_STACKOFFSET=0\n",
      "   KMP_STACKPAD=0\n",
      "   KMP_STACKSIZE=8M\n",
      "   KMP_STORAGE_MAP=false\n",
      "   KMP_TASKING=2\n",
      "   KMP_TASKLOOP_MIN_TASKS=0\n",
      "   KMP_TASK_STEALING_CONSTRAINT=1\n",
      "   KMP_TEAMS_THREAD_LIMIT=8\n",
      "   KMP_USE_YIELD=1\n",
      "   KMP_VERSION=false\n",
      "   KMP_WARNINGS=true\n",
      "   LIBOMP_NUM_HIDDEN_HELPER_THREADS=0\n",
      "   LIBOMP_USE_HIDDEN_HELPER_TASK=false\n",
      "   OMP_AFFINITY_FORMAT='OMP: pid %P tid %i thread %n bound to OS proc set {%A}'\n",
      "   OMP_ALLOCATOR=omp_default_mem_alloc\n",
      "   OMP_CANCELLATION=false\n",
      "   OMP_DEFAULT_DEVICE=0\n",
      "   OMP_DISPLAY_AFFINITY=false\n",
      "   OMP_DISPLAY_ENV=false\n",
      "   OMP_DYNAMIC=false\n",
      "   OMP_MAX_ACTIVE_LEVELS=1\n",
      "   OMP_MAX_TASK_PRIORITY=0\n",
      "   OMP_NESTED: deprecated; max-active-levels-var=1\n",
      "   OMP_NUM_TEAMS=0\n",
      "   OMP_NUM_THREADS='1'\n",
      "   OMP_PROC_BIND='false'\n",
      "   OMP_SCHEDULE='static'\n",
      "   OMP_STACKSIZE=8M\n",
      "   OMP_TARGET_OFFLOAD=DEFAULT\n",
      "   OMP_TEAMS_THREAD_LIMIT=0\n",
      "   OMP_THREAD_LIMIT=2147483647\n",
      "   OMP_TOOL=enabled\n",
      "   OMP_TOOL_LIBRARIES: value is not defined\n",
      "   OMP_TOOL_VERBOSE_INIT: value is not defined\n",
      "   OMP_WAIT_POLICY=PASSIVE\n",
      "\n",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "model = estimator.fit(sdf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5121fe9f",
   "metadata": {},
   "source": [
    "Make prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0ff2a30f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-11-23 23:11:40 WARN  DAGScheduler:69 - Broadcasting large task binary with size 7.2 MiB\n",
      "+--------------------------------------------+-----+------------------------------------------+-------------------------------------------+----------+\n",
      "|features                                    |label|rawPrediction                             |probability                                |prediction|\n",
      "+--------------------------------------------+-----+------------------------------------------+-------------------------------------------+----------+\n",
      "|[24000.0,0.0,0.5,2.0,0.7071067811865476,0.0]|1    |[0.18786715530061926,-0.18786715530061926]|[0.5468291372132481,0.45317086278675195]   |0.0       |\n",
      "|[15000.0,1.0,0.5,2.0,0.7071067811865476,0.0]|0    |[-0.22410400024965002,0.22410400024965002]|[0.44420730923050533,0.5557926907694947]   |1.0       |\n",
      "|[16000.0,0.0,1.0,2.0,0.0,0.5]               |1    |[-27.236102957272454,27.236102957272454]  |[1.4843681839238343E-12,0.9999999999985156]|1.0       |\n",
      "|[14000.0,1.0,1.0,2.0,0.0,0.5]               |1    |[-19.275062351695446,19.275062351695446]  |[4.2554626489277325E-9,0.9999999957445374] |1.0       |\n",
      "|[17000.0,1.0,1.0,1.0,0.5,0.5]               |1    |[-18.481993748724157,18.481993748724157]  |[9.40528799286966E-9,0.999999990594712]    |1.0       |\n",
      "+--------------------------------------------+-----+------------------------------------------+-------------------------------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "[Stage 53:>                                                         (0 + 1) / 1]\r",
      "\r",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "predictions = model.transform(sdf)\n",
    "predictions.show(5, False)"
   ]
  }
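  ,
  {
   "cell_type": "markdown",
   "id": "sketch-read-predictions",
   "metadata": {},
   "source": [
    "The numbers in the output above are consistent with a binary logistic link: the second entry of `probability` matches the sigmoid of the second entry of `rawPrediction`, and `prediction` is the index of the larger probability. The NumPy check below is an interpretation of the displayed values, not a statement about the library's internals; the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Raw score for class 1, copied from the first row of the output above\n",
    "raw_class1 = -0.18786715530061926\n",
    "\n",
    "# The logistic (sigmoid) function maps the raw score to P(class 1)\n",
    "p1 = 1.0 / (1.0 + np.exp(-raw_class1))\n",
    "probability = np.array([1.0 - p1, p1])\n",
    "\n",
    "# The predicted class is the one with the larger probability\n",
    "prediction = float(np.argmax(probability))\n",
    "print(probability, prediction)\n",
    "```"
   ]
  }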
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37tf2_x3",
   "language": "python",
   "name": "py37tf2_x3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
